# Functions to construct baseline vote shares ---------
# Impute vote shares by shifting & scaling vote share & turnout at the county level
impute_16 = function(state, map, d) {
  # Back-cast 2016 precinct votes from 2020 precinct votes.
  #   state: state abbreviation used to subset the county table `d`
  #   map:   precinct data with `state`, `county`, `pre_20_dem_bid`, `pre_20_rep_tru`
  #   d:     county table holding (at least) `state`, `county` and the column
  #          range `turnout_16:mult`, which must include `shift` and `mult`
  # Returns a tibble of `state`, `county` and the `pre_16_*` / `pre_20_*` columns.
  cty = d |>
    filter(.data$state == .env$state) |>
    select(state, county, turnout_16:mult)

  votes_20 = select(as_tibble(map), state, county, pre_20_dem_bid, pre_20_rep_tru)

  left_join(votes_20, cty, by=c("state", "county")) |>
    mutate(turn_20 = pre_20_dem_bid + pre_20_rep_tru,
           # logit Dem. share; undefined shares (e.g. 0/0) fall back to 0, i.e. 50/50
           ldem_20 = coalesce(qlogis(pre_20_dem_bid / turn_20), 0),
           # undo the county's 2016 -> 2020 logit shift to get the 2016 share
           p_dem_16 = plogis(ldem_20 - shift),
           # undo the county's turnout multiplier to get 2016 turnout
           pre_16_dem_cli = p_dem_16 * turn_20/mult,
           pre_16_rep_tru = (1 - p_dem_16) * turn_20/mult) |>
    select(state, county, starts_with("pre_16"), starts_with("pre_20"))
}
impute_20 = function(state, map, d) {
  # Project 2020 precinct votes forward from 2016 precinct votes.
  #   state: state abbreviation used to subset the county table `d`
  #   map:   precinct data with `state`, `county`, `pre_16_dem_cli`, `pre_16_rep_tru`
  #   d:     county table holding (at least) `state`, `county` and the column
  #          range `turnout_16:mult`, which must include `shift` and `mult`
  # Returns a tibble of `state`, `county` and the `pre_16_*` / `pre_20_*` columns.
  cty = d |>
    filter(.data$state == .env$state) |>
    select(state, county, turnout_16:mult)

  votes_16 = select(as_tibble(map), state, county, pre_16_dem_cli, pre_16_rep_tru)

  left_join(votes_16, cty, by=c("state", "county")) |>
    mutate(turn_16 = pre_16_dem_cli + pre_16_rep_tru,
           # logit Dem. share; undefined shares (e.g. 0/0) fall back to 0, i.e. 50/50
           ldem_16 = coalesce(qlogis(pre_16_dem_cli / turn_16), 0),
           # apply the county's 2016 -> 2020 logit shift to get the 2020 share
           p_dem_20 = plogis(ldem_16 + shift),
           # apply the county's turnout multiplier to get 2020 turnout
           pre_20_dem_bid = p_dem_20 * turn_16*mult,
           pre_20_rep_tru = (1 - p_dem_20) * turn_16*mult) |>
    select(state, county, starts_with("pre_16"), starts_with("pre_20"))
}

# Average vote share and turnout (imputing as necessary) to build precinct baseline
build_state_baseline = function(map) {
  # Build the precinct-level partisan baseline for a single state.
  #
  #   map: one state's precinct data with a `state` column and 2016 and/or 2020
  #        presidential vote columns (`pre_16_dem_cli`, `pre_16_rep_tru`,
  #        `pre_20_dem_bid`, `pre_20_rep_tru`). A `county` column is required
  #        whenever one of the two elections has to be imputed.
  #
  # Returns a tibble with one row per row of `map`: `state`, `state_i` (row
  # index into `map`), and `ndv` / `nrv` (baseline Dem. / Rep. votes, rounded
  # to one decimal).
  #
  # Stops with an error if `map` has neither election, since there is then
  # nothing to impute from.
  state = map$state[1]
  has_16 = "pre_16_dem_cli" %in% colnames(map)
  has_20 = "pre_20_dem_bid" %in% colnames(map)

  # fail early and clearly instead of loading the county file and then hitting
  # an "object not found" error further down
  if (!has_16 && !has_20) {
    stop("Map for ", state, " has neither 2016 nor 2020 presidential returns; ",
         "cannot build a baseline.", call. = FALSE)
  }

  # fill in gaps
  if (state == "SD") { # only statewide, so do by hand
    # every row receives the statewide totals: the 2020 share is constant
    # across precincts and 2020 "turnout" is the same constant for each row
    map$pre_20_dem_bid = 150471
    map$pre_20_rep_tru = 261043
  } else if (!has_16 || !has_20) {
    # prepare county data
    raw_cty = read_rds(here("data/medsl_county_president/county_pres.rds"))
    # patch Salt Lake County (FIPS 49035) manually
    patch_ut = tibble(state="UT", county="49035", cty_20_dem_bid=289906, cty_20_rep_tru=230174)

    d = raw_cty |>
      # keep the three candidates of interest; drop DC and county 51515
      # (presumably the defunct Bedford city, VA -- confirm); for UT keep only
      # the TOTAL mode rows
      filter(str_detect(candidate, "(CLINTON|BIDEN|TRUMP)"),
             state_po != "DC", county_fips != "51515",
             !(state_po == "UT" & mode != "TOTAL")) |>
      # normalise to short keys, e.g. year "16", party "dem", candidate "cli"
      transmute(state = state_po,
                county = county_fips,
                year = str_sub(as.character(year), 3),
                party = str_sub(str_to_lower(party), 1, 3),
                candidate = str_remove(candidate, " JR"),
                candidate = str_sub(str_to_lower(word(candidate, -1)), 1, 3),
                votes = candidatevotes) |>
      group_by(state, county, year, party, candidate) |>
      summarize(votes = sum(votes), .groups="drop") |>
      # one row per county, columns like cty_16_dem_cli / cty_20_rep_tru
      pivot_wider(names_from=c("year", "party", "candidate"), names_sep="_",
                  names_prefix="cty_", values_from=votes) |>
      drop_na(county) |>
      rows_update(patch_ut, by=c("state", "county")) |>
      mutate(turnout_16 = cty_16_dem_cli + cty_16_rep_tru,
             turnout_20 = cty_20_dem_bid + cty_20_rep_tru,
             dshare_16 = cty_16_dem_cli / turnout_16,
             dshare_20 = cty_20_dem_bid / turnout_20,
             # county-level 2016 -> 2020 change: logit share shift and turnout ratio
             shift = qlogis(dshare_20) - qlogis(dshare_16),
             mult = turnout_20 / turnout_16)

    if (has_16 && !has_20) {
      map = impute_20(state, map, d)
    } else if (!has_16 && has_20) {
      map = impute_16(state, map, d)
    }
  }

  turnout_16 = with(map, pre_16_dem_cli + pre_16_rep_tru)
  turnout_20 = with(map, pre_20_dem_bid + pre_20_rep_tru)
  dshare_16 = map$pre_16_dem_cli / turnout_16
  dshare_20 = map$pre_20_dem_bid / turnout_20
  # plug holes: borrow the other year's share, else fall back to 50/50
  dshare_16 = coalesce(dshare_16, dshare_20, 0.5)
  dshare_20 = coalesce(dshare_20, dshare_16, 0.5)

  base_turnout = sqrt(turnout_16 * turnout_20) # geometric mean
  base_dshare = 0.5 * (dshare_16 + dshare_20)

  ndv = round(base_turnout * base_dshare, 1)
  nrv = round(base_turnout - ndv, 1)

  # state_i is the row number to match to the plans file later if things get shuffled
  tibble(state=state, state_i=seq_len(nrow(map)), ndv=ndv, nrv=nrv)
}

# Load states and build baseline -------

if (file_exists(here("data/precinct_baseline.csv"))) {
  # a cached baseline exists: reuse it rather than rebuilding every state
  d_baseline = read_csv(here("data/precinct_baseline.csv"), show_col_types=FALSE)
} else {
  # build each state in `state.abb` in turn (with a progress bar) and stack the rows
  d_baseline = map_dfr(cli_progress_along(state.abb), \(i) {
    build_state_baseline(load_50state_map(state.abb[i]))
  })

  # save baseline
  write_csv(d_baseline, here("data/precinct_baseline.csv"))
}


# Fit election model --------
# load data
# keep elections contested by exactly 1 candidate from each party
d_house = read_rds(here("data/medsl_house/house.rds")) |>
  # count Democratic and Republican candidates within each contest
  group_by(year, state, district) |>
  mutate(n_dem = sum(str_starts(party, "DEMOCRAT"), na.rm = TRUE),
         n_rep = sum(str_starts(party, "REPUBLICAN"), na.rm = TRUE)) |>
  ungroup() |>
  # keep only the two major-party rows of contests with exactly one of each
  filter(n_dem == 1, n_rep == 1,
         str_starts(party, "(DEMOCRAT|REPUBLICAN)")) |>
  transmute(year = year,
            state = state_po,
            state_yr = str_c(state, "-", year),
            # district id is tagged with its decade, e.g. 2012-2020 -> 2010
            district = str_c(state, "-", district, "-", 10 * floor((year - 1) / 10)),
            party = str_sub(str_to_lower(party), 1, 3),
            votes = candidatevotes) |>
  # one row per contest with `dem` and `rep` vote columns
  pivot_wider(names_from=party, values_from=votes) |>
  # log ratio of Dem. to Rep. votes (logit of the two-party Dem. share)
  mutate(ldshare = log(dem) - log(rep))


# Fit and save election model
if (!file_exists(here("data/election_model.rds"))) {
  # Runs only when no saved model exists. Fits the model on `d_house`, writes
  # the summary estimates to data/election_model.rds, and regenerates the
  # appendix figures and table under paper/. All outputs are written relative
  # to the project root via here(), matching the guard above.

  # partial-pooling ANOVA
  library(brms)
  m_distr = brm(ldshare ~ (1 | district) + (1 | year),
                 data=d_house, family=student,
                 chains=2, backend="cmdstanr", normalize=FALSE, threads=4)

  # bundle estimates and helper functions
  ests = list(
    scale = posterior_summary(m_distr, variable="sigma", robust=TRUE)[, "Estimate"],
    df = posterior_summary(m_distr, variable="nu", robust=TRUE)[, "Estimate"],
    natl_sd = posterior_summary(m_distr, variable="sd_year__Intercept", robust=TRUE)[, "Estimate"]
  )
  # combined scale of district noise plus the year (national) effect
  ests$scale_natl = with(ests, sqrt(scale^2 + natl_sd^2))
  # placeholders; the bodies are replaced below with the numeric estimates
  # spliced in as literal constants
  ests$prob_win = function(x) x
  ests$prob_win_logit = function(x) x
  ests$prob_win_natl = function(x) x
  ests$prob_win_natl_logit = function(x) x
  # x is a Dem. share for prob_win / prob_win_natl and a logit for the *_logit versions
  rlang::fn_body(ests$prob_win) = rlang::expr(pt(qlogis(x)/!!ests$scale, df=!!ests$df))
  rlang::fn_body(ests$prob_win_logit) = rlang::expr(pt(x/!!ests$scale, df=!!ests$df))
  rlang::fn_body(ests$prob_win_natl) = rlang::expr(pt(qlogis(x)/!!ests$scale_natl, df=!!ests$df))
  rlang::fn_body(ests$prob_win_natl_logit) = rlang::expr(pt(x/!!ests$scale_natl, df=!!ests$df))

  # write to the same project-rooted path that the guard checks; a bare
  # relative path would miss it whenever the working directory differs
  write_rds(ests, here("data/election_model.rds"))

  # Appendix plots -------
  d_re = ranef(m_distr)$district[, , 1] |>
    as.data.frame() |>
    rownames_to_column("distr") |>
    as_tibble()
  # district ids look like "<state>-<district>-<decade>"; keep the 2010 decade
  d_est_2010 = d_re |>
    separate(distr, c("state", "cd_2010", "decade")) |>
    mutate(cd_2010 = as.integer(cd_2010)) |>
    filter(decade == "2010") |>
    select(state, cd_2010, est=Estimate, low=Q2.5, high=Q97.5)
  # aggregate the precinct baseline to 2010 districts; relies on `d_baseline`
  # rows for a state being in the same order as that state's map rows
  d_base_2010 = map_dfr(unique(d_est_2010$state), function(abbr) {
    filter(d_baseline, state == abbr) |>
      group_by(state, cd_2010 = as.integer(load_50state_map(abbr)$cd_2010)) |>
      summarize(ldem = log(sum(ndv)) - log(sum(nrv)), .groups="drop")
  })
  inner_join(d_est_2010, d_base_2010, by=c("state", "cd_2010")) |>
  ggplot(aes(ldem, est, ymin=low, ymax=high)) +
    geom_abline(slope=1, color="red") +
    geom_errorbar(alpha=0.5, size=0.3) +
    geom_point(size=0.6) +
    labs(x="Precinct-result baseline (logit scale)",
         y="Fitted district-decade random effect (logit scale)") +
    theme_bw(base_size=10, base_family="Arial") +
    theme(plot.margin=margin())
  ggsave(here("paper/figures/model_ranef.pdf"), width=7, height=4.5, device=cairo_pdf)
  # number of districts whose baseline falls outside the random-effect interval
  inner_join(d_est_2010, d_base_2010, by=c("state", "cd_2010")) |>
    summarize(outside = sum(ldem > high | ldem < low))

  # text example
  filter(d_baseline, state == "GA") |>
    group_by(state, cd_2020 = as.integer(load_50state_map("GA")$cd_2020)) |>
    summarize(ldem = log(sum(ndv)) - log(sum(nrv)), .groups="drop")
  ldem = -0.247
  1 - ests$prob_win_natl_logit(ldem)
  1 - ests$prob_win_logit(ldem + 0.2)

  # 2018 validation
  d_act_2018 = filter(d_house, year == 2018) |>
    select(-state, -state_yr) |>
    separate(district, c("state", "cd_2010", "decade")) |>
    mutate(cd_2010 = as.integer(cd_2010))
  d_valid_2018 = inner_join(d_act_2018, d_base_2010, by=c("state", "cd_2010")) |>
    rename(ldem_pred = ldem,
           ldem = ldshare) |>
    mutate(resid = ldem - ldem_pred,
           # realized national shift: mean gap between outcome and baseline
           natl_shift = mean(resid),
           q_resid = cume_dist(resid),
           pr_natl = ests$prob_win_natl_logit(ldem_pred),
           pr = ests$prob_win_logit(ldem_pred + natl_shift),
           # the raw residual still contains the national shift, so it is
           # scaled by the combined scale (as pr_natl is), not the district one
           pit_natl = pt(resid/ests$scale_natl, df=ests$df),
           pit = pt((resid - natl_shift)/ests$scale, df=ests$df))

  N_draw = 10000
  n_distr = nrow(d_valid_2018)
  # rows are draws, columns are districts.
  # District noise is Student-t on the fitted residual scale: rt() is unit
  # scale, so it must be multiplied by `ests$scale` to agree with prob_win_*().
  X = ests$scale * matrix(rt(n_distr*N_draw, df=ests$df), nrow=N_draw)
  # one national shift per draw (recycled down each column, i.e. shared by row)
  X = X + rnorm(N_draw, 0, ests$natl_sd)
  # add each district's baseline prediction to its column
  X = X + rep(d_valid_2018$ldem_pred, each=N_draw)
  d_sim = tibble(draw = seq_len(N_draw),
                 dem_natl = rowSums(X > 0),
                 # independent Bernoulli wins given the realized national shift
                 # (rbinom replaces purrr's deprecated rbernoulli)
                 dem = map_int(draw, ~ sum(rbinom(n_distr, 1, d_valid_2018$pr))))
  p1 = ggplot(d_sim, aes(dem_natl)) +
    geom_histogram(aes(y=after_stat(density)), binwidth=1, fill="#444444") +
    geom_vline(xintercept=sum(d_valid_2018$ldem > 0), color="#0064B0", size=1) +
    scale_y_continuous("Density", expand=expansion(mult=c(0, 0.02)), limits=c(0, 0.089)) +
    labs(x="Dem. seats (averaging over nat'l environment)") +
    xlim(80, 175) +
    theme_bw(base_size=10, base_family="Arial")
  p2 = ggplot(d_sim, aes(dem)) +
    geom_histogram(aes(y=after_stat(density)), binwidth=1, fill="#444444") +
    geom_vline(xintercept=sum(d_valid_2018$ldem > 0), color="#0064B0", size=1) +
    scale_y_continuous("Density", expand=expansion(mult=c(0, 0.02)), limits=c(0, 0.089)) +
    labs(x="Dem. seats (conditioning on nat'l environment)") +
    xlim(80, 175) +
    theme_bw(base_size=10, base_family="Arial")
  p1 + p2
  ggsave(here("paper/figures/model_valid_18.pdf"), width=8, height=3.5, device=cairo_pdf)


  library(tidybayes)
  gather_draws(m_distr, r_year[year, var2]) |>
  ggplot(aes(year, .value)) +
    geom_hline(yintercept=0, lty="dashed") +
    stat_gradientinterval(fill_type = "gradient") +
    labs(x="Election", y="Year effect (logit scale)") +
    theme_bw(base_size=10, base_family="Arial") +
    theme(plot.margin=margin())
  ggsave(here("paper/figures/model_yr_est.pdf"), width=7, height=3.5, device=cairo_pdf)

  # LaTeX table of the three scalar estimates
  brms::posterior_summary(m_distr, variable=c("sigma", "sd_year__Intercept", "nu")) |>
    as.data.frame() |>
    `rownames<-`(c("$\\sigma_\\varepsilon$", "$\\sigma_\\beta$", "$\\nu$")) |>
    select(Estimate, `Std. Error`=`Est.Error`) |>
    kableExtra::kbl(digits=3, format = "latex", escape = FALSE, booktabs = TRUE) |>
    write_lines(here("paper/model_ests.tex"))

}

